
Implement vectorized jacobian and allow arbitrary expression shapes #1228


Open · ricardoV94 wants to merge 3 commits into main from vectorized_jacobian

Conversation

ricardoV94 (Member) commented on Feb 20, 2025:

Seeing about a 3.5x speedup (2x for larger x) for this trivial case:

from pytensor import function
from pytensor.gradient import jacobian
import pytensor.tensor as pt

x = pt.vector("x", shape=(3,))
y = pt.outer(x, x)
print(y.type.shape)  # (3, 3)

# Flip vectorize to compare the new vectorized path against the Scan fallback.
jac_y = jacobian(y, x, vectorize=False)
print(jac_y.type.shape)  # (3, 3, 3)

fn = function([x], jac_y, profile=True)
fn.dprint(print_type=True)

%timeit fn([0, 1, 2])
fn([0, 1, 2])

Memory footprint will grow though, especially if intermediate operations are much larger than the final jacobian. Also, not all graphs can be safely vectorized, so I would leave the Scan option as the default for a while.
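For context, here is a minimal sketch of the vectorization idea (not necessarily the exact implementation in this PR): build the gradient graph for a single output entry with a symbolic index, then batch that graph over all indices with vectorize_graph. Variable names are illustrative.

from pytensor.gradient import grad
from pytensor.graph import vectorize_graph
import pytensor.tensor as pt

x = pt.vector("x", shape=(3,))
y = x ** 2

# Gradient of a single entry y[i], with the index left symbolic.
i = pt.lscalar("i")
g_i = grad(y[i], x)

# Batch the single-entry graph over all row indices at once,
# replacing the scalar index by a vector of indices.
rows = pt.arange(y.shape[0])
jac = vectorize_graph(g_i, replace={i: rows})  # shape (3, 3)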

Related Issue

@ricardoV94 ricardoV94 force-pushed the vectorized_jacobian branch from 17ec0cd to 3c6ba6a Compare June 11, 2025 11:18
@ricardoV94 ricardoV94 force-pushed the vectorized_jacobian branch from 3c6ba6a to ff732d6 Compare June 11, 2025 11:20
@ricardoV94 ricardoV94 marked this pull request as ready for review June 11, 2025 11:20
Copilot AI left a comment:

Pull Request Overview

This PR implements an enhanced jacobian computation that supports vectorization and arbitrary expression shapes, yielding improved performance while preserving compatibility via a fallback scan method. Key changes include:

  • Adding a new vectorized branch to the jacobian function in pytensor/gradient.py along with new tests covering scalar, vector, and matrix cases.
  • Updating tests in tests/test_gradient.py to use a parameterized test class for the new vectorize functionality.
  • Minor type hint and function signature adjustments in pytensor/tensor/basic.py and pytensor/graph/replace.py.

Reviewed Changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated no comments.

  • tests/test_gradient.py — New parameterized tests for the jacobian function with the added vectorize flag.
  • pytensor/tensor/basic.py — Added an early return in flatten when the reshaping yields the same number of dimensions.
  • pytensor/graph/replace.py — Updated return type annotations for enhanced clarity and consistency.
  • pytensor/gradient.py — Extended the jacobian function with a vectorize branch and adjusted inner gradient handling.
Comments suppressed due to low confidence (3)

pytensor/gradient.py:2094

  • Using zip with strict=True requires Python 3.10 or later. Please confirm that this requirement is acceptable for the project or consider alternative implementations for compatibility.
for i, (jacobian_single_row, jacobian_matrix) in enumerate(zip(jacobian_single_rows, jacobian_matrices, strict=True)):

pytensor/gradient.py:2104

  • [nitpick] In the non-vectorized branch of jacobian, the inner function unpacks arguments to compute grad(expr[idx], wrt, **grad_kwargs). Please verify that passing 'wrt' as a list directly matches the intended grad() API to avoid potential issues with multidimensional expressions.
idx, expr, *wrt = args

pytensor/gradient.py:2322

  • The change from g_out.zeros_like() to g_out.zeros_like(g_out) is unexpected compared to previous usage. Please ensure that the new call properly infers the shape and dtype without introducing recursion or inconsistency.
return [g_out.zeros_like(g_out) for g_out in g_outs]

@ricardoV94 ricardoV94 force-pushed the vectorized_jacobian branch 4 times, most recently from 1764707 to e2a2665 Compare June 11, 2025 11:43
@ricardoV94 ricardoV94 force-pushed the vectorized_jacobian branch from e2a2665 to 7f161ae Compare June 11, 2025 12:03
ricardoV94 (Member, Author) commented:

We can follow up with a vectorized Hessian, even though @aseyboldt doesn't believe anyone ever needs one. Allowing arbitrary expression shapes there is not as trivial; JAX returns some nested tuple stuff...
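As a rough sketch of that follow-up (assuming the vectorize flag added in this PR), a vectorized Hessian would just be the vectorized jacobian of the gradient:

from pytensor.gradient import grad, jacobian
import pytensor.tensor as pt

x = pt.vector("x", shape=(3,))
cost = (x ** 3).sum()

# Hessian as jacobian-of-gradient; vectorize=True would build one
# batched graph instead of a Scan over the gradient's entries.
hess = jacobian(grad(cost, x), x, vectorize=True)  # shape (3, 3)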

jessegrabowski (Member) left a comment:

Do you want to add a benchmark of scan vs vectorize?
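A rough sketch of what such a benchmark could look like (assuming pytest-benchmark's benchmark fixture; the test name and sizes are illustrative):

import numpy as np
import pytest

import pytensor.tensor as pt
from pytensor import function
from pytensor.gradient import jacobian

@pytest.mark.parametrize("vectorize", [True, False], ids=["vectorize", "scan"])
def test_jacobian_benchmark(benchmark, vectorize):
    x = pt.vector("x", shape=(100,))
    y = pt.outer(x, x)
    fn = function([x], jacobian(y, x, vectorize=vectorize))
    x_val = np.arange(100, dtype=x.type.dtype)
    benchmark(fn, x_val)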

-respect to some parameter ``x`` we need to use `scan`. What we
-do is to loop over the entries in ``y`` and compute the gradient of
+respect to some parameter ``x`` we can use `scan`.
+We loop over the entries in ``y`` and compute the gradient of
Suggested change:
-We loop over the entries in ``y`` and compute the gradient of
+In this case, we loop over the entries in ``y`` and compute the gradient of

``y[i]`` with respect to ``x``.

.. note::

`scan` is a generic op in PyTensor that allows writing in a symbolic
manner all kinds of recurrent equations. While creating
symbolic loops (and optimizing them for performance) is a hard task,
-effort is being done for improving the performance of `scan`. We
-shall return to :ref:`scan<tutloop>` later in this tutorial.
+effort is being done for improving the performance of `scan`.
Suggested change:
-effort is being done for improving the performance of `scan`.
+efforts are being made to improve the performance of `scan`.

>>> from pytensor.graph import vectorize_graph
>>> x = pt.dvector('x')
>>> y = x ** 2
>>> row_tangent = pt.dvector("row_tangent") # Helper variable, it will be replaced during vectorization
I think the term that usually gets used is cotangent_vector?
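The quoted doctest is cut off above; a plausible continuation of the pattern it sets up (hedged, since the exact doc text isn't shown) replaces the helper row_tangent with an identity matrix so the L-operator yields every jacobian row at once:

>>> from pytensor.gradient import Lop
>>> J_row = Lop(y, x, row_tangent)  # one row of the jacobian: row_tangent @ J
>>> jacobian = vectorize_graph(J_row, replace={row_tangent: pt.eye(x.shape[0])})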

@@ -2051,62 +2057,73 @@ def jacobian(expression, wrt, consider_constant=None, disconnected_inputs="raise
output, then a zero variable is returned. The return value is
of same type as `wrt`: a list/tuple or TensorVariable in all cases.
"""
from pytensor.tensor import broadcast_to, eye
Use the actual source for the import? tensor.basic or tensor.shape I guess?

)

amat = matrix()
amat_val = random(4, 5)
-for ndim in (2, 1):
+for ndim in (1,):
Remove useless loop
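I.e., something like this (the surrounding test helper is hypothetical):

# Instead of a loop over a single value:
for ndim in (1,):
    check(amat, amat_val, ndim)
# just inline it:
check(amat, amat_val, ndim=1)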

codecov bot commented on Jun 12, 2025:

Codecov Report

Attention: Patch coverage is 89.28571% with 3 lines in your changes missing coverage. Please review.

Project coverage is 82.03%. Comparing base (5f5be92) to head (7f161ae).

Files with missing lines | Patch % | Lines
pytensor/gradient.py | 88.46% | 1 Missing and 2 partials ⚠️

❌ Your patch check has failed because the patch coverage (89.28%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.


@@            Coverage Diff             @@
##             main    #1228      +/-   ##
==========================================
- Coverage   82.03%   82.03%   -0.01%     
==========================================
  Files         214      214              
  Lines       50398    50408      +10     
  Branches     8897     8902       +5     
==========================================
+ Hits        41345    41352       +7     
- Misses       6848     6850       +2     
- Partials     2205     2206       +1     
Files with missing lines | Coverage Δ
pytensor/graph/replace.py | 84.21% <ø> (ø)
pytensor/tensor/basic.py | 91.70% <100.00%> (+0.01%) ⬆️
pytensor/gradient.py | 77.86% <88.46%> (-0.69%) ⬇️

... and 1 file with indirect coverage changes

